Example of a Dirichlet Process Mixture Model clustering using Gaussians


In [1]:
using DataFrames

# number of samples per cluster
N = 100

# create clusters
Data = DataFrame( x = randn(N), y = randn(N), class = "cluster1" )

# append second cluster
append!(Data, DataFrame( x = randn(N) + 5, y = randn(N) + 5, class = "cluster2" ));

In [2]:
# visualize data
using Gadfly

plot(Data, x = :x, y = :y, color = :class)


Out[2]:
x -25 -20 -15 -10 -5 0 5 10 15 20 25 30 -20.0 -19.5 -19.0 -18.5 -18.0 -17.5 -17.0 -16.5 -16.0 -15.5 -15.0 -14.5 -14.0 -13.5 -13.0 -12.5 -12.0 -11.5 -11.0 -10.5 -10.0 -9.5 -9.0 -8.5 -8.0 -7.5 -7.0 -6.5 -6.0 -5.5 -5.0 -4.5 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 8.5 9.0 9.5 10.0 10.5 11.0 11.5 12.0 12.5 13.0 13.5 14.0 14.5 15.0 15.5 16.0 16.5 17.0 17.5 18.0 18.5 19.0 19.5 20.0 20.5 21.0 21.5 22.0 22.5 23.0 23.5 24.0 24.5 25.0 -20 0 20 40 -20 -19 -18 -17 -16 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 cluster1 cluster2 class -25 -20 -15 -10 -5 0 5 10 15 20 25 30 -20.0 -19.5 -19.0 -18.5 -18.0 -17.5 -17.0 -16.5 -16.0 -15.5 -15.0 -14.5 -14.0 -13.5 -13.0 -12.5 -12.0 -11.5 -11.0 -10.5 -10.0 -9.5 -9.0 -8.5 -8.0 -7.5 -7.0 -6.5 -6.0 -5.5 -5.0 -4.5 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 8.5 9.0 9.5 10.0 10.5 11.0 11.5 12.0 12.5 13.0 13.5 14.0 14.5 15.0 15.5 16.0 16.5 17.0 17.5 18.0 18.5 19.0 19.5 20.0 20.5 21.0 21.5 22.0 22.5 23.0 23.5 24.0 24.5 25.0 -20 0 20 40 -20 -19 -18 -17 -16 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 y

train DPM using collabsed Gibbs sampling


In [3]:
using BNP

D = 2 # 2 dimensional data
N = 200 # number of data points

# data matrix
X = zeros(D, N)

X[1,:] = convert(Array, Data[:x])
X[2,:] = convert(Array, Data[:y])

# init base distribution parameters
mu0 = vec(mean(X, 2))
kappa0 = 9.0
nu0 = 5.0
Sigma0 = eye(D) * 10

# base distribution and concentration parameter (Gaussian with Normal Inverse Wishart Prior)
H = GaussianWishart(mu0, kappa0, nu0, Sigma0)

# train Dirichlet Process Mixture Model
result = train(DPM(H), Gibbs(), RandomInitialisation(k = 10), X);

visualize inferred models


In [5]:
using Interact

K = zeros(Int, N)

# interactively loop over all iterations
@manipulate for iteration = 1:size(result, 1)
    
    idx = unique(result[iteration].Z)
    
    for n in 1:N
        K[n] = findfirst(idx .== result[iteration].Z[n])
    end

    plot(x = X[1,:], y = X[2,:], color = K)
    
end


Out[5]:
x -25 -20 -15 -10 -5 0 5 10 15 20 25 30 -20.0 -19.5 -19.0 -18.5 -18.0 -17.5 -17.0 -16.5 -16.0 -15.5 -15.0 -14.5 -14.0 -13.5 -13.0 -12.5 -12.0 -11.5 -11.0 -10.5 -10.0 -9.5 -9.0 -8.5 -8.0 -7.5 -7.0 -6.5 -6.0 -5.5 -5.0 -4.5 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 8.5 9.0 9.5 10.0 10.5 11.0 11.5 12.0 12.5 13.0 13.5 14.0 14.5 15.0 15.5 16.0 16.5 17.0 17.5 18.0 18.5 19.0 19.5 20.0 20.5 21.0 21.5 22.0 22.5 23.0 23.5 24.0 24.5 25.0 -20 0 20 40 -20 -19 -18 -17 -16 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 2.0 1.0 1.5 Color -25 -20 -15 -10 -5 0 5 10 15 20 25 30 -20.0 -19.5 -19.0 -18.5 -18.0 -17.5 -17.0 -16.5 -16.0 -15.5 -15.0 -14.5 -14.0 -13.5 -13.0 -12.5 -12.0 -11.5 -11.0 -10.5 -10.0 -9.5 -9.0 -8.5 -8.0 -7.5 -7.0 -6.5 -6.0 -5.5 -5.0 -4.5 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 8.5 9.0 9.5 10.0 10.5 11.0 11.5 12.0 12.5 13.0 13.5 14.0 14.5 15.0 15.5 16.0 16.5 17.0 17.5 18.0 18.5 19.0 19.5 20.0 20.5 21.0 21.5 22.0 22.5 23.0 23.5 24.0 24.5 25.0 -20 0 20 40 -20 -19 -18 -17 -16 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 y

visualize energy function


In [6]:
LLH = map(x -> x.energy, result)

# plot
p1 = plot( x = collect(1:size(result, 1)), 
            y = LLH, 
            Geom.line,
            Guide.xlabel("iteration"), 
            Guide.ylabel("log likelihood", orientation=:vertical) )

p2 = plot( x = collect(1:size(result, 1)), 
            y = LLH, 
            Geom.smooth,
            Guide.xlabel("iteration"), 
Guide.ylabel("log likelihood (smoothed)", orientation=:vertical) )

# stack together
vstack(p1, p2)


Out[6]:
iteration -150 -100 -50 0 50 100 150 200 250 -100 -95 -90 -85 -80 -75 -70 -65 -60 -55 -50 -45 -40 -35 -30 -25 -20 -15 -10 -5 0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 -100 0 100 200 -100 -90 -80 -70 -60 -50 -40 -30 -20 -10 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 -5.2 -5.0 -4.8 -4.6 -4.4 -4.2 -4.0 -3.8 -3.6 -3.4 -3.2 -3.0 -2.8 -2.6 -2.4 -5.00 -4.95 -4.90 -4.85 -4.80 -4.75 -4.70 -4.65 -4.60 -4.55 -4.50 -4.45 -4.40 -4.35 -4.30 -4.25 -4.20 -4.15 -4.10 -4.05 -4.00 -3.95 -3.90 -3.85 -3.80 -3.75 -3.70 -3.65 -3.60 -3.55 -3.50 -3.45 -3.40 -3.35 -3.30 -3.25 -3.20 -3.15 -3.10 -3.05 -3.00 -2.95 -2.90 -2.85 -2.80 -2.75 -2.70 -2.65 -2.60 -5 -4 -3 -2 -5.00 -4.95 -4.90 -4.85 -4.80 -4.75 -4.70 -4.65 -4.60 -4.55 -4.50 -4.45 -4.40 -4.35 -4.30 -4.25 -4.20 -4.15 -4.10 -4.05 -4.00 -3.95 -3.90 -3.85 -3.80 -3.75 -3.70 -3.65 -3.60 -3.55 -3.50 -3.45 -3.40 -3.35 -3.30 -3.25 -3.20 -3.15 -3.10 -3.05 -3.00 -2.95 -2.90 -2.85 -2.80 -2.75 -2.70 -2.65 -2.60 log likelihood (smoothed) iteration -150 -100 -50 0 50 100 150 200 250 -100 -95 -90 -85 -80 -75 -70 -65 -60 -55 -50 -45 -40 -35 -30 -25 -20 -15 -10 -5 0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 -100 0 100 200 -100 -90 -80 -70 -60 -50 -40 -30 -20 -10 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 -5.2 -5.0 -4.8 -4.6 -4.4 -4.2 -4.0 -3.8 -3.6 -3.4 -3.2 -3.0 -2.8 -2.6 -2.4 -5.00 -4.95 -4.90 -4.85 -4.80 -4.75 -4.70 -4.65 -4.60 -4.55 -4.50 -4.45 -4.40 -4.35 -4.30 -4.25 -4.20 -4.15 -4.10 -4.05 -4.00 -3.95 -3.90 -3.85 -3.80 -3.75 -3.70 -3.65 -3.60 -3.55 -3.50 -3.45 -3.40 -3.35 -3.30 -3.25 -3.20 -3.15 -3.10 -3.05 -3.00 -2.95 -2.90 -2.85 -2.80 -2.75 -2.70 -2.65 -2.60 -5 -4 -3 -2 -5.00 -4.95 -4.90 -4.85 -4.80 -4.75 -4.70 -4.65 -4.60 -4.55 -4.50 -4.45 -4.40 -4.35 -4.30 -4.25 -4.20 -4.15 -4.10 -4.05 -4.00 -3.95 -3.90 -3.85 -3.80 -3.75 -3.70 -3.65 -3.60 -3.55 -3.50 -3.45 -3.40 -3.35 -3.30 -3.25 -3.20 -3.15 -3.10 -3.05 -3.00 -2.95 -2.90 -2.85 -2.80 -2.75 -2.70 -2.65 -2.60 log likelihood

visualize further information


In [7]:
# number of clusters
C = [length(unique(x.Z)) for x in result]

# alpha parameter
A = map(x -> x.α, result)

# plot
p1 = plot( x = collect(1:size(result, 1)), 
            y = C, Geom.line, 
            Guide.xlabel("iteration"), 
            Guide.ylabel("number of clusters", orientation=:vertical) )
p2 = plot( x = collect(1:size(result, 1)), 
            y = A, 
            Geom.line, 
            Guide.xlabel("iteration"), 
            Guide.ylabel("alpha", orientation=:vertical) )

# stack together
vstack(p1, p2)


Out[7]:
iteration -150 -100 -50 0 50 100 150 200 250 -100 -95 -90 -85 -80 -75 -70 -65 -60 -55 -50 -45 -40 -35 -30 -25 -20 -15 -10 -5 0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 -100 0 100 200 -100 -90 -80 -70 -60 -50 -40 -30 -20 -10 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 -4.0 -3.8 -3.6 -3.4 -3.2 -3.0 -2.8 -2.6 -2.4 -2.2 -2.0 -1.8 -1.6 -1.4 -1.2 -1.0 -0.8 -0.6 -0.4 -0.2 0.0 0.2 0.4 0.6 0.8 1.0 1.2 1.4 1.6 1.8 2.0 2.2 2.4 2.6 2.8 3.0 3.2 3.4 3.6 3.8 4.0 4.2 4.4 4.6 4.8 5.0 5.2 5.4 5.6 5.8 6.0 6.2 6.4 6.6 6.8 7.0 7.2 7.4 7.6 7.8 8.0 -5 0 5 10 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 alpha iteration -150 -100 -50 0 50 100 150 200 250 -100 -95 -90 -85 -80 -75 -70 -65 -60 -55 -50 -45 -40 -35 -30 -25 -20 -15 -10 -5 0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 -100 0 100 200 -100 -90 -80 -70 -60 -50 -40 -30 -20 -10 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 -20 -15 -10 -5 0 5 10 15 20 25 30 35 -15.0 -14.5 -14.0 -13.5 -13.0 -12.5 -12.0 -11.5 -11.0 -10.5 -10.0 -9.5 -9.0 -8.5 -8.0 -7.5 -7.0 -6.5 -6.0 -5.5 -5.0 -4.5 -4.0 -3.5 -3.0 -2.5 -2.0 -1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 6.5 7.0 7.5 8.0 8.5 9.0 9.5 10.0 10.5 11.0 11.5 12.0 12.5 13.0 13.5 14.0 14.5 15.0 15.5 16.0 16.5 17.0 17.5 18.0 18.5 19.0 19.5 20.0 20.5 21.0 21.5 22.0 22.5 23.0 23.5 24.0 24.5 25.0 25.5 26.0 26.5 27.0 27.5 28.0 28.5 29.0 29.5 30.0 -20 0 20 40 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 number of clusters